Master's thesis case study 3: Bandit's with stopping¶

In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
import numpy
import torch
from adaptive_nof1 import *
from adaptive_nof1.policies import *
from adaptive_nof1.helpers import *
from adaptive_nof1.inference import *
from adaptive_nof1.metrics import *
from matplotlib import pyplot as plt
import seaborn
from adaptive_nof1.patient_explorer import show_patient_explorer
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
In [3]:
# Setup generic n-of-1 parameters
block_length = 5
max_length = 10 * block_length
number_of_actions = 2
number_of_patients = 100
In [4]:
# Scenarios
class NormalModel(Model):
    def __init__(self, patient_id, mean, variance):
        self.rng = numpy.random.default_rng(patient_id)
        self.mean = mean
        self.variance = variance
        self.patient_id = patient_id

    def multivariate_normal_distribution(debug_data):
        cov = torch.diag_embed(torch.tensor(numpy.sqrt(self.variance)))
        return torch.distributions.MultivariateNormal(torch.tensor(self.mean), cov)

    def generate_context(self, history):
        return {}

    @property
    def additional_config(self):
        return {"expectations_of_interventions": self.mean}

    @property
    def number_of_interventions(self):
        return len(self.mean)

    def observe_outcome(self, action, context):
        treatment_index = action["treatment"]
        return {"outcome": self.rng.normal(self.mean[treatment_index], numpy.sqrt(self.variance[treatment_index]))}

    def __str__(self):
        return f"NormalModel({self.mean, self.variance})"

generating_scenario_I = lambda patient_id: NormalModel(patient_id, mean=[0, 0], variance=[1,1])
generating_scenario_II = lambda patient_id: NormalModel(patient_id, mean=[1, 0], variance=[1,1])
generating_scenario_III = lambda patient_id: NormalModel(patient_id, mean=[2, 0], variance=[1,1])
In [5]:
# Inference Model
inference_model = lambda: NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)

# Stopping Time
ALPHA_STOPPING = 0.01
def alpha_stopping_time(history, context):
    model = NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)
    model.update_posterior(history, number_of_actions)
    probabilities = model.approximate_max_probabilities(number_of_actions, context)
    return 1 - max(probabilities) < ALPHA_STOPPING
In [6]:
# Policies
fixed_policy = StoppingPolicy(
    policy = BlockPolicy(
        block_length = block_length,
        internal_policy = FixedPolicy(
            number_of_actions=2,
            inference_model = inference_model(),
        ),
    ),
    stopping_time = alpha_stopping_time,
)

explore_then_commit = StoppingPolicy(
    policy= BlockPolicy(
        block_length = block_length,
        internal_policy = ExploreThenCommit(
        number_of_actions=2,
        exploration_length=4,
        block_length = block_length,
        inference_model = inference_model(),
    
        ),
    ),
    stopping_time = alpha_stopping_time,
)


thompson_sampling_policy = StoppingPolicy(
        policy = BlockPolicy(
            block_length = block_length,
            internal_policy = ThompsonSampling(
                inference_model=inference_model(),
                number_of_actions=2,
            ),
        ),
    stopping_time = alpha_stopping_time,
)

ucb_policy = StoppingPolicy(
    policy = BlockPolicy(
        block_length = block_length,
        internal_policy = UpperConfidenceBound(
            inference_model=inference_model(),
            number_of_actions=2,
            epsilon=0.05,
        ),
    ),
    stopping_time = alpha_stopping_time,
)
In [7]:
# Full crossover study
study_designs = {
    "n_patients": [number_of_patients],
    "policy": [fixed_policy, explore_then_commit, thompson_sampling_policy, ucb_policy],
    "model_from_patient_id": [
        generating_scenario_I, generating_scenario_II, generating_scenario_III,
    ]
}
configurations = generate_configuration_cross_product(study_designs)
In [8]:
ENABLE_SIMULATION = False
if ENABLE_SIMULATION:
    print("Simulation was enabled")
else:
    print("Simulation was disabled")
Simulation was enabled
In [9]:
if ENABLE_SIMULATION:
    calculated_series, config_to_simulation_data = simulate_configurations(
        configurations, max_length
    )
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
In [10]:
if ENABLE_SIMULATION:
    write_to_disk("data/2024-02-11-mt_case_study_3_data.json", [calculated_series, config_to_simulation_data])
else:
    calculated_series, config_to_simulation_data = load_from_disk("data/2024-02-11-mt_case_study_3_data.json")
In [23]:
# Todo: make the output table in a way that we chose the maximum index
def debug_data_to_torch_distribution(debug_data):
    mean = debug_data["mean"]
    # + the true variance of 1
    standard_deviation = numpy.sqrt(numpy.array(debug_data["variance"]) + 1)
    cov = torch.diag_embed(torch.tensor(standard_deviation))
    return torch.distributions.MultivariateNormal(torch.tensor(mean), cov)

def data_to_true_distribution(data):
    mean = data.additional_config["expectations_of_interventions"]
    cov = torch.eye(len(mean))
    return torch.distributions.MultivariateNormal(torch.tensor(mean), cov)


metrics = [
    SimpleRegretWithMean(),
    BestArmIdentification(),
    CumulativeRegret(),
    Length(),
    KLDivergence(data_to_true_distribution = data_to_true_distribution, debug_data_to_posterior_distribution=debug_data_to_torch_distribution),
]
model_mapping = {
    "NormalModel(([0, 0], [1, 1]))": "I",
    "NormalModel(([1, 0], [1, 1]))": "II",
    "NormalModel(([2, 0], [1, 1]))": "III",
}
policy_mapping = {
    "StoppingPolicy(BlockPolicy(FixedPolicy))": "Fixed",
    "StoppingPolicy(BlockPolicy(ThompsonSampling(NormalKnownVariance(0, 1, 1))))": "TS",
    "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))": "UCB",
    "StoppingPolicy(BlockPolicy(ExploreThenCommit(4,NormalKnownVariance(0, 1, 1))))": "ETC",
}

df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series], metrics, {"model": lambda x: model_mapping[x], "policy": lambda x: policy_mapping[x]}
)

df = df.reset_index(drop=True)
max_t_indices = df.groupby(["policy", "metric", "model", "patient_id"])["t"].idxmax()
filtered_df = df.iloc[max_t_indices]
filtered_df = filtered_df.reset_index(drop=True)
groupby_columns = ["model", "policy"]

pivoted_df = filtered_df.pivot(
    index=["model", "policy", "patient_id"],
    columns="metric",
    values="score",
)
table = pivoted_df.groupby(groupby_columns).agg(['mean', 'std'])

policy_ordering = ["Fixed", "ETC", "UCB", "TS"]

# Convert the 'policy' column in the MultiIndex to a Categorical type with the specified order
table = table.reset_index()
table['policy'] = pd.Categorical(table['policy'], categories=policy_ordering, ordered=True)

# Sort the DataFrame first by 'model' then by the now-ordered 'policy'
sorted_table = table.sort_values(by=['model', 'policy']).set_index(groupby_columns)[["Cumulative Regret (outcome)", "KL Divergence", "Simple Regret With Mean", "Length", "Best Arm Identification With Mean"]]
sorted_table.index.names = ["S.", "Policy"]
sorted_table
Out[23]:
metric Cumulative Regret (outcome) KL Divergence Simple Regret With Mean Length Best Arm Identification With Mean
mean std mean std mean std mean std mean std
S. Policy
I Fixed -1.015252 7.228124 0.058222 0.108264 0.0 0.000000 48.11 8.292500 0.54 0.500908
ETC -0.906465 7.043759 0.068086 0.104436 0.0 0.000000 47.87 8.621860 0.49 0.502418
UCB -0.990027 7.111509 0.093519 0.123903 0.0 0.000000 47.94 8.338568 0.52 0.502117
TS -1.03192 7.280245 0.073209 0.099991 0.0 0.000000 48.48 7.134324 0.51 0.502418
II Fixed -13.437944 8.339200 0.064499 0.055955 0.0 0.000000 24.04 13.325695 1.0 0.000000
ETC -18.781947 14.631094 0.070053 0.063430 0.0 0.000000 26.43 15.462078 1.0 0.000000
UCB -33.334714 21.230854 0.149162 0.231549 0.02 0.140705 37.12 17.352716 0.98 0.140705
TS -24.481239 17.566796 0.153106 0.242918 0.01 0.100000 33.44 15.109245 0.99 0.100000
III Fixed -12.886026 5.586778 0.181561 0.183543 0.0 0.000000 10.41 4.109843 1.0 0.000000
ETC -13.261922 6.508062 0.1809 0.182769 0.0 0.000000 10.59 4.408658 1.0 0.000000
UCB -46.015805 42.648110 0.872675 0.977566 0.28 0.697470 26.03 19.479829 0.86 0.348735
TS -42.058965 42.416344 0.899014 0.987336 0.26 0.675995 25.34 18.321054 0.87 0.337998
In [12]:
with open('mt_resources/7-stopping/01-table-part-1.tex', 'w') as file:
    str = sorted_table[["Cumulative Regret (outcome)", "KL Divergence", "Simple Regret With Mean"]].style.format(precision=1).to_latex(hrules=True)
    print(str)
    file.write(str)
\begin{tabular}{lllrlrlr}
\toprule
 & metric & \multicolumn{2}{r}{Cumulative Regret (outcome)} & \multicolumn{2}{r}{KL Divergence} & \multicolumn{2}{r}{Simple Regret With Mean} \\
 &  & mean & std & mean & std & mean & std \\
S. & Policy &  &  &  &  &  &  \\
\midrule
\multirow[c]{4}{*}{I} & Fixed & -1.0 & 7.2 & 0.1 & 0.1 & 0.0 & 0.0 \\
 & ETC & -0.9 & 7.0 & 0.1 & 0.1 & 0.0 & 0.0 \\
 & UCB & -1.0 & 7.1 & 0.1 & 0.1 & 0.0 & 0.0 \\
 & TS & -1.0 & 7.3 & 0.1 & 0.1 & 0.0 & 0.0 \\
\multirow[c]{4}{*}{II} & Fixed & -13.4 & 8.3 & 0.1 & 0.1 & 0.0 & 0.0 \\
 & ETC & -18.8 & 14.6 & 0.1 & 0.1 & 0.0 & 0.0 \\
 & UCB & -33.3 & 21.2 & 0.1 & 0.2 & 0.0 & 0.1 \\
 & TS & -24.5 & 17.6 & 0.2 & 0.2 & 0.0 & 0.1 \\
\multirow[c]{4}{*}{III} & Fixed & -12.9 & 5.6 & 0.2 & 0.2 & 0.0 & 0.0 \\
 & ETC & -13.3 & 6.5 & 0.2 & 0.2 & 0.0 & 0.0 \\
 & UCB & -46.0 & 42.6 & 0.9 & 1.0 & 0.3 & 0.7 \\
 & TS & -42.1 & 42.4 & 0.9 & 1.0 & 0.3 & 0.7 \\
\bottomrule
\end{tabular}

In [34]:
with open('mt_resources/7-stopping/01-table-part-2.tex', 'w') as file:
    str = sorted_table[["Length", "Best Arm Identification With Mean"]].style.format(precision=2).to_latex(hrules=True)
    print(str)
    file.write(str)
\begin{tabular}{lllrlr}
\toprule
 & metric & \multicolumn{2}{r}{Length} & \multicolumn{2}{r}{Best Arm Identification With Mean} \\
 &  & mean & std & mean & std \\
S. & Policy &  &  &  &  \\
\midrule
\multirow[c]{4}{*}{I} & Fixed & 48.11 & 8.29 & 0.54 & 0.50 \\
 & ETC & 47.87 & 8.62 & 0.49 & 0.50 \\
 & UCB & 47.94 & 8.34 & 0.52 & 0.50 \\
 & TS & 48.48 & 7.13 & 0.51 & 0.50 \\
\multirow[c]{4}{*}{II} & Fixed & 24.04 & 13.33 & 1.00 & 0.00 \\
 & ETC & 26.43 & 15.46 & 1.00 & 0.00 \\
 & UCB & 37.12 & 17.35 & 0.98 & 0.14 \\
 & TS & 33.44 & 15.11 & 0.99 & 0.10 \\
\multirow[c]{4}{*}{III} & Fixed & 10.41 & 4.11 & 1.00 & 0.00 \\
 & ETC & 10.59 & 4.41 & 1.00 & 0.00 \\
 & UCB & 26.03 & 19.48 & 0.86 & 0.35 \\
 & TS & 25.34 & 18.32 & 0.87 & 0.34 \\
\bottomrule
\end{tabular}

In [14]:
def rename_df(df):
    df["policy_#_metric_#_model_p"] = df["policy"].apply(lambda x: policy_mapping[x])
    df['policy'] = pd.Categorical(df['policy_#_metric_#_model_p'], categories=policy_ordering, ordered=True)
    return df

SeriesOfSimulationsData.plot_lines(
    [s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
    [
        CumulativeRegret(),
    ],
    legend_position=(0.02,0.3),
    process_df = rename_df,
)
plt.ylabel('Regret')
plt.savefig("mt_resources/7-stopping/01_cumulative_regret.pdf", bbox_inches="tight")
No description has been provided for this image
In [15]:
SeriesOfSimulationsData.plot_lines(
    [s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
    [
        SimpleRegretWithMean(),
    ],
    legend_position=(0.8,1.0),
    process_df = rename_df,
)
plt.ylabel('Simple Regret')
plt.savefig("mt_resources/7-stopping/01_simple_regret.pdf", bbox_inches="tight")
No description has been provided for this image
In [16]:
SeriesOfSimulationsData.plot_lines(
    [s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
    [
        KLDivergence(data_to_true_distribution = data_to_true_distribution, debug_data_to_posterior_distribution=debug_data_to_torch_distribution)
    ],
    legend_position=(0.8,1.0),
    process_df = rename_df,
)
plt.ylabel('KL Divergence')
plt.savefig("mt_resources/7-stopping/01-kl-divergence.pdf", bbox_inches="tight")
No description has been provided for this image
In [26]:
df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
    [ IsStopped() ],
)
groupby_df_sum = rename_df(df).groupby(["policy", "model", "t"]).sum()

ax = seaborn.lineplot(
    data=groupby_df_sum,
    x="t",
    y="score",
    hue="policy",
    # units="patient_id",
    #estimator=numpy.median,
    #errorbar=lambda x: (numpy.quantile(x, 0.25), numpy.quantile(x, 0.75)),
)
plt.ylabel("Number of patients")
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/7-stopping/01_is_stopped.pdf", bbox_inches="tight")
/var/folders/2g/v44yvb1n6sdgnp5mwbh8_qgc0000gn/T/ipykernel_95227/3657439798.py:5: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
  groupby_df_sum = rename_df(df).groupby(["policy", "model", "t"]).sum()
No description has been provided for this image
In [18]:
plot_allocations_for_calculated_series(calculated_series)
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value
  layout_plot = gridplot(
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value
  layout_plot = gridplot(
Out[18]:
In [19]:
plot_allocations_for_calculated_series([s for s in calculated_series if s["configuration"]["policy"] == "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))" and s["configuration"]["model"] == "NormalModel(([0, 0], [1, 1]))"])
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value
  layout_plot = gridplot(
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value
  layout_plot = gridplot(
Out[19]:
In [29]:
plot_allocations_for_calculated_series([s for s in calculated_series if s["configuration"]["policy"] == "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))" and s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"])
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value
  layout_plot = gridplot(
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value
  layout_plot = gridplot(
Out[29]:
In [30]:
show_patient_explorer(calculated_series, filter_attributes=["variance_2", "variance_1", "variance_3", "probabilities_3", "is_start_of_block", "is_stopped"])
Out[30]:
In [ ]: